import struct
from enum import Enum
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel

class BaseModel(PydanticBaseModel):
    """Project-wide pydantic base model.

    Enables ``arbitrary_types_allowed`` so fields may use types pydantic has
    no built-in validator for (e.g. ``Image.Image`` in ``StreamingPrompt``).
    """
    class Config:
        # Permit non-pydantic types such as PIL images as field types.
        arbitrary_types_allowed = True
        
class Status(Enum):
    """Execution state of a prompt, held in ``SimplePrompt.status``.

    ``NOT_STARTED`` is the default for new ``SimplePrompt`` instances. Member
    values are lowercase, hyphenated strings.
    """
    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    UPLOADING = "uploading"

class StreamingPrompt(BaseModel):
    """State for a streaming prompt, stored in ``streaming_prompt_metadata``.

    NOTE(review): under pydantic v1, ``Union`` members are tried left to
    right and the ``str`` validator accepts decodable ``bytes``, so some byte
    inputs may arrive as ``str`` — confirm against the pydantic version used.
    """
    workflow_api: Any
    auth_token: str
    # Named inputs; each value is text, raw bytes, or a PIL image.
    inputs: dict[str, Union[str, bytes, Image.Image]]
    # Ids of prompts currently running for this stream.
    running_prompt_ids: set[str] = set()
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    
class SimplePrompt(BaseModel):
    """Tracking state for a single prompt, stored in ``prompt_metadata``.

    NOTE: field defaults must not end in a trailing comma — ``x = None,``
    makes the default the one-element tuple ``(None,)`` rather than ``None``
    (and ``(False,)`` is truthy, which silently flipped ``is_realtime``).
    """
    # Endpoints used for reporting; presumably URLs — confirm with callers.
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]

    token: Optional[str]

    workflow_api: dict
    # Current execution state; starts as NOT_STARTED.
    status: Status = Status.NOT_STARTED
    progress: set = set()
    last_updated_node: Optional[str] = None
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    start_time: Optional[float] = None

# Connection objects keyed by session id (sid); ``send_bytes`` looks targets
# up here and awaits their ``send_bytes`` method.
sockets = {}
# Prompt state registries, keyed by id (presumably the prompt id — confirm).
prompt_metadata: dict[str, SimplePrompt] = dict()
streaming_prompt_metadata: dict[str, StreamingPrompt] = dict()

class BinaryEventTypes:
    """Integer event tags for binary frames.

    ``encode_bytes`` writes the tag as the frame's 4-byte big-endian header.
    """
    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2
    
# Width, in bytes, of the NUL-padded output-id field in preview-image frames;
# longer ids are truncated to fit.
max_output_id_length = 24

def encode_output_id(output_id: Optional[str], max_length: int) -> bytes:
    """Encode *output_id* as a fixed-width field of exactly *max_length* bytes.

    The id is truncated to *max_length* characters and right-padded with NUL
    bytes. Non-ASCII characters are replaced with ``?`` so every character
    maps to one byte and the width stays fixed. ``None`` encodes as an empty
    (all-NUL) id.
    """
    text = (output_id or "")[:max_length]
    return text.ljust(max_length, "\x00").encode("ascii", "replace")


async def send_image(image_data, sid=None, output_id: Optional[str] = None):
    """Send a preview image, tagged with *output_id*, as a binary frame.

    Args:
        image_data: Sequence ``(image_type, image, max_size, quality)``.
            ``image_type`` is the Pillow format name, also used to pick the
            type tag. ``image`` is the image to encode. ``max_size``, when
            not ``None``, bounds both dimensions. ``quality`` is forwarded to
            ``image.save``.
        sid: Target socket id; ``None`` sends to every socket (see
            ``send_bytes``).
        output_id: Identifier embedded in the frame, truncated and NUL-padded
            to ``max_output_id_length`` bytes. ``None`` is sent as an empty id.

    The payload passed to ``send_bytes`` is a 4-byte big-endian image-type
    tag (JPEG=1, PNG=2, WEBP=3, anything else=1), then the fixed-width
    output id, then the encoded image bytes.
    """
    image_type = image_data[0]
    image = image_data[1]
    max_size = image_data[2]
    quality = image_data[3]

    if max_size is not None:
        if hasattr(Image, "Resampling"):
            resampling = Image.Resampling.BILINEAR
        else:
            # Older Pillow releases have no Resampling enum.
            resampling = Image.ANTIALIAS
        image = ImageOps.contain(image, (max_size, max_size), resampling)

    # Unknown formats fall back to tag 1, as before.
    type_num = {"JPEG": 1, "PNG": 2, "WEBP": 3}.get(image_type, 1)

    buffer = BytesIO()
    # 4 bytes: big-endian image-type tag.
    buffer.write(struct.pack(">I", type_num))
    # max_output_id_length bytes: NUL-padded output id.
    buffer.write(encode_output_id(output_id, max_output_id_length))
    # compress_level only affects PNG; level 1 favours encoding speed.
    image.save(buffer, format=image_type, quality=quality, compress_level=1)
    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, buffer.getvalue(), sid=sid)
        
async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, printing instead of raising on send errors.

    Only aiohttp client errors and ``ConnectionResetError`` are absorbed;
    any other exception propagates to the caller.
    """
    try:
        await function(message)
    except (
        aiohttp.ClientError,
        aiohttp.ClientPayloadError,
        ConnectionResetError,
    ) as send_error:
        # Best-effort delivery: one dropped peer must not abort the caller.
        print("send error:", send_error)

def encode_bytes(event, data):
    """Build a binary frame: a 4-byte big-endian *event* header, then *data*.

    Returns a new ``bytearray``. Raises ``RuntimeError`` if *event* is not
    an ``int``.
    """
    if isinstance(event, int):
        frame = bytearray(struct.pack(">I", event))
        frame.extend(data)
        return frame
    raise RuntimeError(f"Binary event types must be integers, got {event}")

async def send_bytes(event, data, sid=None):
    """Frame *data* under binary *event* and send it to one or all sockets.

    With ``sid`` of ``None`` the frame goes to every entry in ``sockets``.
    Otherwise only that socket receives it; an unknown ``sid`` is ignored.
    Send errors are handled by ``send_socket_catch_exception``.
    """
    frame = encode_bytes(event, data)

    print("sending image to ", event, sid)

    if sid is not None:
        if sid in sockets:
            await send_socket_catch_exception(sockets[sid].send_bytes, frame)
        return

    # Snapshot the connections first: the registry could change while a send
    # is suspended at an await.
    for connection in tuple(sockets.values()):
        await send_socket_catch_exception(connection.send_bytes, frame)